from BasicLibraries import *
from Functions.Miscellaneous import utils as ut
import get_data as gd
from pysankey import sankey
from matplotlib.gridspec import GridSpec
from matplotlib import rc
import matplotlib.lines as mlines
import get_data as gd_

# Matplotlib defaults applied once at import time for every figure drawn by
# this module: a tick-label padding of 8 points on the x axis, LaTeX text
# rendering with the cmbright package, 24pt body/legend text and a Helvetica
# sans-serif font family.
plt.rcParams['xtick.major.pad'] = '8'
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{cmbright}')
rc('font', size=24)
rc('font', family='sans-serif', **{'sans-serif': ['Helvetica']})
rc('legend', fontsize=24)

# Machine-specific OneDrive root; not referenced in the lines shown here.
ONEDRIVE = '/Users/simonecenci/OneDrive - Imperial College London/'

# Line colours for the sector panel of plot_pop_level_diversification
# (Descriptive.get_BRD keeps four GICS level-1 sectors).
color = ['#631919', '#009c9b', '#fb7c54', '#85af38']
class Descriptive():
    def __init__(self):
        """Create the helper; the methods shown keep no per-instance state."""
    def remove_duplicated_fingerprints(self, x):
        """Drop documents whose ``md5_fingerprint`` occurs more than once.

        For every fingerprint shared by several rows exactly one row survives:

        * if the duplicates carry more than one ``data_type``, every
          ``'crawled'`` row is dropped first (non-crawled files win);
        * of the rows that remain, all but the last one (in row order) are
          dropped.

        Rows are removed by ``mrg`` value, so any row whose ``mrg`` equals that
        of a dropped row is dropped as well.

        Parameters
        ----------
        x : pandas.DataFrame
            Must contain ``md5_fingerprint``, ``mrg`` and ``data_type``.

        Returns
        -------
        pandas.DataFrame
            ``x`` without the duplicates, original row order preserved.
        """
        counts = x.groupby('md5_fingerprint')['mrg'].count()
        duplicated = counts[counts > 1].index

        to_remove = []
        candidates = x[x.md5_fingerprint.isin(duplicated)]
        # groupby keeps the original row order inside each group.
        for _, group in candidates.groupby('md5_fingerprint'):
            if group.data_type.nunique(dropna=False) > 1:
                crawled = group.data_type == 'crawled'
                to_remove.extend(group.loc[crawled, 'mrg'].tolist())
                # Keep de-duplicating among the non-crawled rows only (the old
                # code trimmed the full group here, which could drop every copy).
                group = group[~crawled]
            to_remove.extend(group['mrg'].iloc[:-1].tolist())

        return x[~x.mrg.isin(to_remove)]
    def get_BRD(self, sdgs, diversification_measure='entropy'):
        """Build the filtered firm-year panel and its yearly diversification summaries.

        Data come from ``gd.get_initiatives(sdgs)`` and ``gd.get_behaviour``.
        The panel is then restricted to:

        * GICS level-1 sectors Industrial, Energy, Material and Utilities;
        * de-duplicated documents (``remove_duplicated_fingerprints``);
        * firms with a report from 2020 on, or more than three reports before 2020;
        * firm-years with ``total_effort > 2``;
        * macro regions Asia-Pacific, Europe and United States and Canada;
        * report years after 2005 (applied after the strategy columns are built).

        Parameters
        ----------
        sdgs : list of int
            SDGs passed to ``gd.get_initiatives`` and
            ``gd_.make_strategy_diversification``.
        diversification_measure : str
            ``'concentration'``, ``'entropy'``, ``'entropy_segments'`` or
            ``'simpson'``; selects how the ``diversification`` column is built.

        Returns
        -------
        tuple
            ``(initiatives, missing_sdgs, initiatives_sdgs, initiatives_to_remove,
            dt, res, resT, resS, R, RN, S, SN)``. ``dt`` is the filtered panel;
            ``res`` / ``resT`` are yearly means of ``diversification`` /
            ``total_effort``; ``resS`` is the yearly ``ut.utils().stdErr`` of
            ``diversification``; ``R``/``RN`` and ``S``/``SN`` are the same mean /
            stdErr pivoted to year rows x macro-region or GICS level-1 columns.

        Raises
        ------
        ValueError
            If ``diversification_measure`` is not a supported name.
        """
        measures = ('concentration', 'entropy', 'entropy_segments', 'simpson')
        if diversification_measure not in measures:
            raise ValueError(
                f'diversification_measure must be one of {measures}, '
                f'got {diversification_measure!r}')

        initiatives, missing_sdgs, initiatives_sdgs, initiatives_to_remove = gd.get_initiatives(sdgs)
        dt = gd.get_behaviour(initiatives_to_remove, missing_sdgs)
        dt = dt[dt.GICS_level_1.isin(['Industrial', 'Energy', 'Material', 'Utilities'])]
        dt = self.remove_duplicated_fingerprints(dt)

        # Firms reporting from 2020 on, plus firms with more than three reports
        # before 2020. ``.copy()`` detaches the filtered frame so the column
        # assignments below do not write into a view (SettingWithCopyWarning).
        post_gvkey = list(dt[dt.rfyear >= 2020].gvkey.unique())
        n_pre = dt[dt.rfyear < 2020][['gvkey', 'rfyear']].groupby('gvkey').count()
        pre_gvkey = list(n_pre[n_pre > 3].dropna().index)
        dt = dt[dt.gvkey.isin(np.unique(post_gvkey + pre_gvkey))].copy()

        # Per-initiative totals: sum of every column whose name contains the
        # initiative name. NOTE(review): substring match -- also picks up any
        # other column containing that text.
        for i in initiatives:
            init = [col for col in dt.columns if i in col]
            dt[i] = dt[init].sum(axis = 1)
        # Share of each action x SDG count in the firm-year's initiative total;
        # the denominator does not change inside the loop, so compute it once.
        initiatives_total = dt[initiatives].sum(axis = 1)
        for i in initiatives_sdgs:
            dt[i+'_%'] = dt[i] / initiatives_total

        dt['total_effort'] = dt[initiatives_sdgs].sum(axis = 1)
        dt['mean_effort'] = dt[initiatives_sdgs].mean(axis = 1)
        dt = dt[dt.total_effort > 2]

        # Only three macro regions are kept; MacroRegionShort relabels
        # 'United States and Canada' as 'North America'.
        # NOTE(review): Latin America / Africa / Middle-East firms are dropped
        # here; if they were meant to be merged into North America / Europe,
        # remap before filtering -- confirm intent.
        dt = dt[dt.MacroRegion.isin(['Asia-Pacific', 'Europe', 'United States and Canada'])].copy()
        dt['MacroRegionShort'] = dt['MacroRegion'].replace(['United States and Canada'], ['North America'])

        dt['concentration'] = dt[[i+'_%' for i in initiatives_sdgs]].std(axis = 1)
        if diversification_measure == 'concentration':
            dt['diversification'] = (1. - dt['concentration']).apply(np.log)
        elif diversification_measure == 'entropy':
            dt['diversification'] = dt.apply(lambda x: gd_.get_entropy(x, initiatives_sdgs), axis = 1)
        elif diversification_measure == 'entropy_segments':
            dt['diversification'] = dt.apply(lambda x: gd_.get_entropy_segments(x, initiatives_sdgs), axis = 1)
            dt['segs'] = dt.apply(lambda x: gd_.get_segments(x, initiatives_sdgs), axis = 1)
        else:  # 'simpson' (validated above)
            dt['diversification'] = dt.apply(lambda x: gd_.get_simpson(x, initiatives_sdgs), axis = 1)

        # Per-strategy diversification and effort columns.
        for strategy in ['Risk mitigation', 'Stakeholders engagement', 'Innovation']:
            ST, TE = gd_.make_strategy_diversification(dt, strategy, sdgs, diversification_measure, sector_normalisation = False)
            dt[strategy] = ST
            dt['total_effort_'+strategy] = TE

        dt = dt[dt.rfyear > 2005].copy()
        yearly = dt[['rfyear', 'diversification']].groupby('rfyear')
        res = yearly.mean()
        resS = (1. * yearly.apply(ut.utils().stdErr)['diversification']).to_frame()
        resT = dt[['rfyear', 'total_effort']].groupby('rfyear').mean()

        # Region summaries use MacroRegionShort under the name 'MacroRegion' so
        # the pivoted column axis keeps that name.
        regions = dt[['rfyear', 'MacroRegionShort', 'diversification']].rename(columns={'MacroRegionShort': 'MacroRegion'})
        R, RN = self._yearly_group_stats(regions, 'MacroRegion')
        S, SN = self._yearly_group_stats(dt, 'GICS_level_1')

        return initiatives, missing_sdgs, initiatives_sdgs, initiatives_to_remove, dt, res, resT, resS, R, RN, S, SN

    def _yearly_group_stats(self, frame, group_col):
        """Return (mean, stdErr) of ``diversification`` per rfyear x ``group_col``.

        Both results are pivoted wide: rows are ``rfyear``, columns are the
        values of ``group_col``; the error uses ``ut.utils().stdErr``.
        """
        grouped = frame[['rfyear', group_col, 'diversification']].groupby(['rfyear', group_col])
        mean = grouped.mean().reset_index().pivot(index='rfyear', columns=group_col, values='diversification')
        stderr = grouped.apply(ut.utils().stdErr)['diversification'].reset_index()
        stderr = stderr.pivot(index='rfyear', columns=group_col, values='diversification')
        return mean, stderr
    
    
   
    
    def plot_pop_level_diversification(self, dt, res,resS,  S, SN, R, RN, save_=False, ylim=(1.55, 1.95)):
        """Draw the three-panel population-level response-diversity figure.

        Panel A (left, full height) shows the yearly series ``res`` with
        ``resS`` error bars on the left axis. A twin right axis shows the yearly
        mean and ``ut.utils().stdErr`` of the ``'Risk mitigation'``,
        ``'Stakeholders engagement'`` and ``'Innovation'`` columns of ``dt``.
        Panel B (top right) plots ``S`` with ``SN`` error bars. Panel C (bottom
        right) plots ``R`` with ``RN`` error bars. ``get_BRD`` returns the
        per-sector (S) and per-region (R) pivots used here.

        Parameters
        ----------
        dt : pandas.DataFrame
            Needs ``rfyear`` and the three strategy columns.
        res, resS : pandas.DataFrame
            Single-column yearly values and error bars for panel A. They are
            relabelled on copies, so the caller's frames are not modified.
        S, SN, R, RN : pandas.DataFrame
            Wide year x group values and error bars for panels B and C.
        save_ : bool
            If true, save to ``Figures/Figure2.pdf``.
        ylim : tuple of float or None
            Left-axis limits of panel A. The default matches the entropy
            measure; ``None`` keeps matplotlib's autoscaling.
        """
        # Relabel copies so the caller's frames keep their own column names.
        res = res.copy()
        res.columns = ['Diversification']
        resS = resS.copy()
        resS.columns = ['Diversification']

        fig = plt.figure(figsize = (18,10))
        gs = GridSpec(nrows=2, ncols=2,hspace=0.025,wspace=0.25)
        a, b, c = [plt.cm.winter,  plt.cm.autumn_r, plt.cm.bone]
        inner_colors = [a(.5), b(.5), c(.5)]
        ax0 = fig.add_subplot(gs[:, 0])

        # Yearly mean / stdErr of the strategy columns, shortened and reordered.
        strategy_cols = ['rfyear', 'Innovation', 'Risk mitigation', 'Stakeholders engagement']
        Z = dt[strategy_cols].groupby('rfyear').mean().round(3)
        Z = Z.rename(columns = {'Risk mitigation': 'Risk Mit.', 'Stakeholders engagement': 'Stake.Eng.'})
        ZE = dt[strategy_cols].groupby('rfyear').apply(ut.utils().stdErr).round(3).drop(columns = ['rfyear'])
        ZE.columns = Z.columns
        order = ['Risk Mit.', 'Stake.Eng.', 'Innovation']
        Z = Z[order]
        ZE = ZE[order]

        res.round(3).plot(c='#e95057', ax = ax0, marker = 'o', ms=8, legend = False,
                          yerr = resS['Diversification'], capsize = 4, alpha = 0.85)
        if ylim is not None:
            ax0.set_ylim(*ylim)

        ax0A = ax0.twinx()
        Z.plot(alpha = 0.75, ls = '--', ax = ax0A, yerr = ZE, marker = 'o', capsize = 4, color = inner_colors)
        ax0A.margins(x=0)
        ax0A.legend(loc = 'center', bbox_to_anchor = (0.5,1.1), ncol = 3)
        ax0.set_xlabel('')
        ax0.set_ylabel('Response Diversity', color = '#e95057', labelpad = 30, fontsize = 34)
        ax0.text(-0.15, 1.05, 'A', transform=ax0.transAxes,
                 fontsize=24, fontweight='bold', va='top', ha='right')

        ax1 = fig.add_subplot(gs[0, 1])
        S.plot(yerr = SN, capsize  = 4, marker = 'o', ms=8, color = color, alpha = 0.85, ax = ax1)
        ax1.legend(loc = 'center left', bbox_to_anchor = (1,0.5), ncol= 1, title = '')
        ax1.margins(x=.01)
        ax1.set_xlabel('')
        ax1.text(-0.1, 1.1, 'B', transform=ax1.transAxes,
                 fontsize=24, fontweight='bold', va='top', ha='right')

        ax2 = fig.add_subplot(gs[1, 1])
        R.plot(yerr = RN, capsize  = 4, marker = 'o', ms = 8, color = ['#2122a5', '#b478cd', '#0b5e55'], alpha = 0.85, ax = ax2)
        ax2.legend(loc = 'center left', bbox_to_anchor = (1,0.5), ncol= 1, title = '')
        ax2.margins(x=.01)
        ax2.set_xlabel('')
        ax2.text(-0.1, 1., 'C', transform=ax2.transAxes,
                 fontsize=24, fontweight='bold', va='top', ha='right')

        plt.tight_layout()
        if save_:
            plt.savefig('Figures/Figure2.pdf', dpi = 100, bbox_inches = 'tight')
    def firm_level_diversification_example(self, X, min_obs=8, effort_quantile=0.95):
        """Rank well-observed, high-effort firms by mean diversification.

        Keeps firms with more than ``min_obs`` non-missing ``rfyear`` rows.
        Among those firms' rows, keeps only the ones whose ``total_effort``
        strictly exceeds its ``effort_quantile`` quantile. Then averages
        ``diversification`` per ``gvkey``.

        Parameters
        ----------
        X : pandas.DataFrame
            Needs ``gvkey``, ``rfyear``, ``total_effort`` and ``diversification``.
        min_obs : int
            A firm needs strictly more rows than this (default 8).
        effort_quantile : float
            ``total_effort`` quantile a row must strictly exceed (default 0.95).

        Returns
        -------
        pandas.DataFrame
            Indexed by ``gvkey`` with a ``diversification`` column. Sorted
            ascending and limited to strictly positive values.
        """
        n_obs = X.groupby('gvkey')['rfyear'].count()
        Z = X[X.gvkey.isin(n_obs[n_obs > min_obs].index)]
        Z = Z[Z.total_effort > Z.total_effort.quantile(effort_quantile)]
        Z = Z[['diversification', 'gvkey']].groupby('gvkey').mean().sort_values(by = 'diversification')
        return Z[Z.diversification > 0]
    def plot_sankey_matrix_examples(self, X, undiversified, diversified, sdgs, initiatives, save_fig = False):
        """Draw an action -> SDG Sankey diagram for each of two example firms.

        The undiversified firm is drawn first (suggested gvkey 11304), then the
        diversified one (suggested gvkey 104652). For each firm, the company
        name(s) in ``conml`` are printed. When ``save_fig`` is true, the current
        figure is saved to ``Figures/Figure1SI_U.pdf`` or
        ``Figures/Figure1SI_D.pdf`` respectively.
        """
        examples = (
            (undiversified, 'Non diversified firm:', 'Figures/Figure1SI_U.pdf'),
            (diversified, 'Diversified firm:', 'Figures/Figure1SI_D.pdf'),
        )
        for firm_key, caption, out_path in examples:
            firm_rows = X[X.gvkey == firm_key]
            self.plot_sankey_mat(firm_rows, sdgs, initiatives, save_=save_fig)
            print(caption, firm_rows.conml.unique())
            if save_fig:
                plt.savefig(out_path, dpi = 100, bbox_inches = 'tight')

    def plot_full_sankey(self, X,sdgs, initiatives, save_fig = True):
        """Draw the action -> SDG Sankey diagram for all rows of ``X``.

        When ``save_fig`` is true (the default), the current figure is saved to
        ``Figures/Figure1SI_FULL.pdf``. Previously the file was written even
        when ``save_fig`` was false.
        """
        self.plot_sankey_mat(X, sdgs, initiatives, save_=save_fig)
        if save_fig:
            plt.savefig('Figures/Figure1SI_FULL.pdf', dpi = 100, bbox_inches = 'tight')


    def get_firm(self, X, idx, initiatives_sdgs, RM, SE, IN):
        """Aggregate a group of firms' initiative counts per action and per strategy.

        Parameters
        ----------
        X : pandas.DataFrame
            Needs ``gvkey`` and every column in ``initiatives_sdgs``.
        idx : collection
            gvkeys to include (matched with ``isin``).
        initiatives_sdgs : list of str
            Column names of the form ``'<action> - <SDG>'``. The part before
            ``' - '`` is the action. ``'donation & funding'`` is renamed
            ``'donations'``, and ``'r&d investments'`` becomes the
            LaTeX-escaped R&D label.
        RM, SE, IN : list of str
            Action names (after renaming) for the Risk mitigation, Stakeholders
            engagement and Innovation strategies.

        Returns
        -------
        M : pandas.Series
            Total count per action (name ``0``, index named ``'A'``), ordered
            ``RM + SE + IN``. Raises ``KeyError`` if an action is absent.
        vals : pandas.Series
            The three strategy totals (name ``0``) in RM, SE, IN order.
        """
        totals = X.loc[X.gvkey.isin(idx), initiatives_sdgs].sum()
        # Build the action column explicitly rather than relying on the name
        # reset_index would give the former column axis.
        per_action = pd.DataFrame({0: totals.to_numpy(),
                                   'A': [name.split(' - ')[0] for name in totals.index]})
        # Raw string: the backslash escapes '&' for LaTeX and is not a valid
        # Python escape sequence.
        per_action['A'] = per_action['A'].replace(['donation & funding', 'r&d investments'], ['donations', r'R\&D'])
        M = per_action.groupby('A').sum()[0]
        M = M.loc[RM+SE+IN]
        vals = pd.Series([M.loc[group].sum() for group in (RM, SE, IN)], name=0)
        return M, vals
    def make_pie_charts_firm(self, D, idx_undiversified, idx_diversified, initiatives_sdgs, save_=False):
        """Draw side-by-side nested pie charts for two firms (or groups of firms).

        For ``idx_undiversified`` (left) and ``idx_diversified`` (right),
        ``get_firm`` aggregates ``D``. The inner ring (radius 0.4-0.7) shows the
        14 per-action totals. The outer ring (radius 0.7-1) shows the Risk
        mitigation / Stakeholders engagement / Innovation totals. Two legends
        are added: one for strategies and one for actions.

        Parameters
        ----------
        D : pandas.DataFrame
            Passed to ``get_firm``; needs ``gvkey`` and the action x SDG columns.
        idx_undiversified, idx_diversified : collection
            gvkeys passed to ``get_firm`` (matched with ``isin``).
        initiatives_sdgs : list of str
            Action x SDG column names.
        save_ : bool
            If true, save to ``Figures/Figure1.pdf``.
        """
        # Re-apply the rc settings this module sets at import time, in case
        # they were changed since.
        plt.rcParams['xtick.major.pad']='8'
        rc('text', usetex=True)
        rc('font', size=24)
        rc('legend', fontsize=24)
        rc('text.latex', preamble=r'\usepackage{cmbright}')
        rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
        # Actions per strategy, named as get_firm labels them. The raw string
        # keeps the LaTeX escape for '&' without an invalid Python escape.
        RM = ['adoption of standards and rules', 'assessment and measurement', 'asset modification', 'modification of procedures', 'training']
        SE = ['volunteerism', 'pricing', 'incentives', 'donations', 'communication']
        IN = ['association', r'R\&D', 'new products', 'organizational structuring']

        font_color = '#525252'
        size = 0.3  # width of each ring
        a, b, c = [plt.cm.winter,  plt.cm.autumn_r, plt.cm.bone]
        # One shade per action (inner ring), taken from its strategy's colormap.
        action_colors = [a(.1), a(.25), a(.5), a(0.6), a(0.7),
                         b(.1), b(.25), b(.5), b(0.75), b(0.95),
                         c(.1), c(.25), c(.5), c(0.75)]
        # One colour per strategy (outer ring).
        strategy_colors = [a(.5), b(.5), c(.5)]
        fig, ax = plt.subplots(1, 2, figsize=(14,12))
        for panel, IDX in zip(ax, [idx_undiversified, idx_diversified]):
            M, vals = self.get_firm(D, IDX, initiatives_sdgs, RM, SE, IN)
            panel.pie(M,
                      radius=1-size,
                      colors=action_colors,
                      labels=None,
                      textprops={'color':font_color},
                      wedgeprops=dict(width=size, edgecolor='w'))
            panel.pie(vals,
                      radius=1,
                      colors=strategy_colors,
                      labels = None,
                      wedgeprops=dict(width=size, edgecolor='w'))

        strategy_handles = [mlines.Line2D([], [], color=col, marker='', ls='-', lw = 15, label=name)
                            for col, name in zip(strategy_colors, ['Risk Mit.', 'Stake.Eng.', 'Innovation'])]
        legend1 = plt.legend(handles=strategy_handles, loc='upper center',
                             bbox_to_anchor=(-0.05, 1.1), ncol=3)
        # Upper-case only the first character: str.capitalize would lower-case
        # the rest and print the R&D acronym as "R&d".
        action_handles = [mlines.Line2D([], [], color=col, marker='', ls='-', lw = 15, label=name[:1].upper() + name[1:])
                          for col, name in zip(action_colors, RM+SE+IN)]
        # The second plt.legend call replaces the first on the current axes,
        # so the strategy legend is re-attached as an extra artist.
        plt.legend(handles=action_handles, loc='upper center',
                   bbox_to_anchor=(1.5,1.1), ncol=1)
        plt.gca().add_artist(legend1)
        ax[0].text(-0.6,-1.5, 'Undiversified firm', fontsize = 28)
        ax[1].text(-0.5,-1.5, 'Diversified firm', fontsize = 28)
        if save_:
            plt.savefig('Figures/Figure1.pdf', dpi = 100, bbox_inches = 'tight')
            
    def make_pie_charts_sector(self, D, gvkeys, labels, initiatives_sdgs, save_=False):
        """Draw up to four nested pie charts in a 2x2 grid, one per group of firms.

        Groups in ``gvkeys`` fill the panels row by row. Each group is
        aggregated with ``get_firm``: the inner ring shows the 14 per-action
        totals and the outer ring the three strategy totals. ``labels[k]`` is
        written on panel ``k``. If ``gvkeys`` has more groups than ``labels``,
        the first unlabelled panel gets a white ring (wedge sizes taken from
        the previous panel's action totals) and drawing stops. Groups beyond
        the fourth are ignored because the grid has only four panels.

        Parameters
        ----------
        D : pandas.DataFrame
            Passed to ``get_firm``.
        gvkeys : iterable
            One gvkey collection per panel (matched with ``isin``).
        labels : sequence of str
            Panel captions, in panel order.
        initiatives_sdgs : list of str
            Action x SDG column names.
        save_ : bool
            If true, save to ``Figures/Sectors_diversification.pdf``.
        """
        # Re-apply the rc settings this module sets at import time.
        plt.rcParams['xtick.major.pad']='8'
        rc('text', usetex=True)
        rc('font', size=24)
        rc('legend', fontsize=24)
        rc('text.latex', preamble=r'\usepackage{cmbright}')
        rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
        # Actions per strategy, named as get_firm labels them (raw string keeps
        # the LaTeX escape for '&' without an invalid Python escape).
        RM = ['adoption of standards and rules', 'assessment and measurement', 'asset modification', 'modification of procedures', 'training']
        SE = ['volunteerism', 'pricing', 'incentives', 'donations', 'communication']
        IN = ['association', r'R\&D', 'new products', 'organizational structuring']

        font_color = '#525252'
        size = 0.3  # width of each ring
        a, b, c = [plt.cm.winter,  plt.cm.autumn_r, plt.cm.bone]
        action_colors = [a(.1), a(.25), a(.5), a(0.6), a(0.7),
                         b(.1), b(.25), b(.5), b(0.75), b(0.95),
                         c(.1), c(.25), c(.5), c(0.75)]
        strategy_colors = [a(.5), b(.5), c(.5)]
        fig, ax = plt.subplots(2, 2, figsize=(14,12))

        M = None
        # Panels are filled row by row: (0,0), (0,1), (1,0), (1,1).
        for k, (IDX, panel) in enumerate(zip(gvkeys, ax.flat)):
            if k >= len(labels):
                # No caption left for this group: blank the panel with a white
                # ring shaped like the previous one, then stop.
                if M is not None:
                    panel.pie(M,
                              radius=1-size,
                              colors='w',
                              labels=None,
                              textprops={'color':font_color},
                              wedgeprops=dict(width=size, edgecolor='w'))
                break
            M, vals = self.get_firm(D, IDX, initiatives_sdgs, RM, SE, IN)
            panel.pie(M,
                      radius=1-size,
                      colors=action_colors,
                      labels=None,
                      textprops={'color':font_color},
                      wedgeprops=dict(width=size, edgecolor='w'))
            panel.pie(vals,
                      radius=1,
                      colors=strategy_colors,
                      labels = None,
                      wedgeprops=dict(width=size, edgecolor='w'))

        # Captions: left-column panels at x=0.6, right-column panels at x=-1.4.
        for k, (label, panel) in enumerate(zip(labels, ax.flat)):
            panel.text(0.6 if k % 2 == 0 else -1.4, -1., label, fontsize = 28)

        strategy_handles = [mlines.Line2D([], [], color=col, marker='', ls='-', lw = 15, label=name)
                            for col, name in zip(strategy_colors, ['Risk Mit.', 'Stake.Eng.', 'Innovation'])]
        legend1 = plt.legend(handles=strategy_handles, loc='upper center',
                             bbox_to_anchor=(-0.15, 2.3), ncol=3)
        # Upper-case only the first character so the R&D acronym stays intact.
        action_handles = [mlines.Line2D([], [], color=col, marker='', ls='-', lw = 15, label=name[:1].upper() + name[1:])
                          for col, name in zip(action_colors, RM+SE+IN)]
        # The second plt.legend call replaces the first on the current axes,
        # so the strategy legend is re-attached as an extra artist.
        plt.legend(handles=action_handles, loc='upper center',
                   bbox_to_anchor=(1.8,2.1), ncol=1)
        plt.gca().add_artist(legend1)
        if save_:
            plt.savefig('Figures/Sectors_diversification.pdf', dpi = 100, bbox_inches = 'tight')


    def plot_sankey_mat(self, X, sdgs, actions, save_=False):
        """Draw an action -> SDG Sankey diagram from action x SDG count columns.

        Parameters
        ----------
        X : pandas.DataFrame
            Must contain a column ``'<action> - SDG <n>'`` for every action in
            ``actions`` and every ``n`` in ``sdgs``. Each column's sum,
            truncated to int, is the number of rows fed to ``sankey`` for that
            (action, SDG) pair.
        sdgs : iterable of int
            SDG numbers to include.
        actions : list of str
            Action names as they appear in the column names.
        save_ : bool
            Not used inside this method; callers save the figure themselves.

        Creates a new 40x36 figure, leaves it as the current figure and
        returns None.
        """
        # Column names for every (SDG, action) pair, grouped by SDG.
        bk = list(np.ravel([[a+' - SDG '+str(s) for a in actions] for s in sdgs]))
        X = X[bk]
        #== Make sankey plot
        z = X.copy()
        sk = []
        # Expand each column into int(total) identical rows [action, 'SDG n', total].
        # NOTE(review): sankey() below gets no weights, so flow sizes presumably
        # come from these row counts -- confirm against the installed pysankey.
        for i in z.columns[::-1]:
            action = i.split(' - ')[0]
            sdg = i.split(' - ')[1]
            repetition = z[i].sum()
            l = [[action, sdg, repetition]]
            sk.extend(l*int(repetition))

        sk = pd.DataFrame(sk, columns = ['Action', 'SDG', 'weight'])
        # Sort by action then SDG, both descending; SDGs are compared as
        # integers (not strings) before the 'SDG n' labels are restored.
        sk['SDG'] = sk['SDG'].apply(lambda x: int(x.split('SDG ')[1]))
        sk = sk.sort_values(by = ['Action', 'SDG'], ascending = False)
        sk['SDG'] = sk['SDG'].apply(lambda x: 'SDG '+str(x))
        # Display names: capitalise, then restore the R&D acronym that
        # capitalize() lower-cased.
        sk['Action'] = sk['Action'].apply(lambda x: x.capitalize())
        sk['Action'] = sk['Action'].apply(lambda x: x.replace('R&d', 'R&D'))
        actions = list(sk.Action.unique())
        overall = actions + ['SDG '+str(s) for s in sdgs]
        # One palette colour per distinct action (reversed 14-colour palette),
        # then a placeholder colour per distinct SDG in sk; every SDG present
        # in sk is recoloured below.
        # NOTE(review): with more than 14 distinct actions the extra actions
        # get no entry in col_dict.
        cols = ['#EBEEEE', '#9D9D9D', '#002147', '#008AFF', '#00BEFA', '#B5F0E7',
                '#FFF5BE', '#EC7300', '#D24000', '#E40428','#e08b00', '#c2b300',
                '#92d600', '#08CDAE'][::-1][:len(sk.Action.unique())] + ['#08CDAE']*len(sk.SDG.unique())
        
        col_dict = {overall[i]: cols[i] for i in range(len(cols))}

        #==== Assign the SDG color to be the one of the most common action
        # For each SDG, take the (action, SDG) pair with the largest 'weight'
        # (the column total) and give the SDG that action's colour.
        max_mapping = sk.groupby(['Action', 'SDG']).last().reset_index().sort_values(by = ['SDG', 'weight']).groupby('SDG').last().reset_index()[['SDG', 'Action']]
        for n in max_mapping['SDG']:
            col_dict[n] = col_dict[max_mapping[max_mapping.SDG == n]['Action'].iloc[0]]
        #====
        plt.figure(figsize =(40,36))
        ax = plt.subplot()
        # left/right are the per-row action and SDG labels built above.
        sankey(left = sk['Action'], 
               right = sk['SDG'], 
               aspect=140,
            fontsize=40, colorDict=col_dict, ax = ax)
        plt.tight_layout()

     
        
        
    #================================================================================
    #================================================================================
    #================================================================================
    def make_simulated_entropy(self, M, IS,save_=False, n_sims=2000, div_score=False):
        """Compare simulated and empirical entropy against the number of initiatives.

        Each of ``n_sims`` simulations does the following:

        * Draws a total from ``M.number_of_initiatives``. The total is halved
          with probability 0.7 when above 2, then clipped to
          [min, 99.9th percentile].
        * Draws a 0/1 mask over the ``IS`` columns; each column is kept with
          probability 3/4.
        * Draws three count vectors: uniform on [0, 4), Gamma(1, 1) and
          lognormal(0, 1), each truncated to integers.
        * Masks each vector, rescales it to the drawn total and computes its
          entropy with ``gd_.get_entropy``.

        The figure shows four things. Top left: simulated entropy vs total.
        Bottom left: histograms of the simulated totals and of
        ``M.total_effort``. Middle: empirical entropy vs total effort, in
        linear and log scale, from ``get_BRD(..., 'entropy')``. Right (only
        when ``div_score`` is true): the same for ``'entropy_segments'``.

        Parameters
        ----------
        M : pandas.DataFrame
            Needs ``number_of_initiatives`` and ``total_effort`` columns.
        IS : list of str
            Action x SDG names; the simulated shares use ``'<name>_%'`` columns.
        save_ : bool
            If true, save to ``Figures/Figure2_SI.pdf``.
        n_sims : int
            Number of simulations (default 2000).
        div_score : bool
            Also build the ``'entropy_segments'`` panel. When false (default),
            that extra ``get_BRD`` run is skipped.

        Returns
        -------
        (u, g, l) : tuple of pandas.DataFrame
            Simulated ``[total, entropy]`` rows for the uniform, gamma and
            lognormal draws.
        """
        n_inits = M.number_of_initiatives
        pct_cols = [i+'_%' for i in IS]
        n_cols = len(pct_cols)
        # Clipping bounds for the simulated totals (loop-invariant).
        low, high = n_inits.min(), n_inits.quantile(0.999)

        u, g, l = [], [], []
        for _ in range(n_sims):
            fixed_sum = np.random.choice(n_inits)
            # uniform() is drawn only when fixed_sum > 2, as before.
            if fixed_sum > 2 and np.random.uniform() < 0.7:
                fixed_sum /= 2
            fixed_sum = np.clip(fixed_sum, low, high)
            # Column mask: 1 unless both fair coin flips are 0.
            mask = (np.random.choice(2, n_cols) + np.random.choice(2, n_cols)).clip(0, 1)

            u.append(self._simulate_entropy_sample(np.random.uniform(0, 4, n_cols), mask, fixed_sum, pct_cols, IS))
            g.append(self._simulate_entropy_sample(np.random.gamma(1, 1, n_cols), mask, fixed_sum, pct_cols, IS))
            l.append(self._simulate_entropy_sample(np.random.lognormal(0, 1, n_cols), mask, fixed_sum, pct_cols, IS))

        u, g, l = pd.DataFrame(u), pd.DataFrame(g), pd.DataFrame(l)
        simulations = (('Uniform', u, 0.65), ('Gamma', g, 0.55), ('Lognormal', l, 0.45))

        fig = plt.figure(figsize = (32,8))
        gs = GridSpec(nrows=2, ncols=3)

        # Top left: simulated entropy vs simulated total, one fit per distribution.
        # seaborn >= 0.12 only accepts x/y as keywords.
        ax = fig.add_subplot(gs[0, 0])
        for name, sim, _ in simulations:
            sns.regplot(x=sim[0], y=sim[1], ax=ax, label=self._rho_label(name, sim[0], sim[1]))
        ax.set_xlabel(r'Total number of initiatives')
        ax.set_ylabel('Entropy')
        ax.set_title('Simulations')

        # Bottom left: simulated totals vs the empirical total effort.
        ax1 = fig.add_subplot(gs[1, 0])
        for name, sim, alpha in simulations:
            sim[0].plot.hist(ax=ax1, bins=50, alpha=alpha, density=True, label=self._rho_label(name, sim[0], sim[1]))
        ax1.set_xlabel(r'Total number of initiatives')
        M.total_effort.plot.hist(ax=ax1, bins=50, alpha=0.5, density=True, label='Empirical distribution')
        ax1.legend()

        # Empirical panels; the segments-based run is only computed when shown.
        sdgs_options = [6,7,9,11,12,13,14,15]
        dtA = self.get_BRD(sdgs_options, 'entropy')[4]
        self._plot_effort_panel(fig.add_subplot(gs[:, 1]), dtA, 'Entropy')
        if div_score:
            dtB = self.get_BRD(sdgs_options, 'entropy_segments')[4]
            self._plot_effort_panel(fig.add_subplot(gs[:, 2]), dtB, 'Diversification score')

        plt.tight_layout()

        if save_:
            plt.savefig('Figures/Figure2_SI.pdf', dpi = 200, bbox_inches = 'tight')

        return u,g,l

    def _simulate_entropy_sample(self, raw, mask, fixed_sum, pct_cols, IS):
        """Turn one random draw into a simulated firm and return ``[total, entropy]``.

        ``raw`` is truncated to integers, zeroed where ``mask`` is 0 and
        rescaled to sum to ``fixed_sum``. The resulting shares are placed in a
        one-row frame with columns ``pct_cols`` and scored with
        ``gd_.get_entropy``. An all-zero masked draw gives NaN values.
        """
        counts = pd.DataFrame(raw).astype(int).values.ravel() * mask
        counts = counts*fixed_sum/np.sum(counts)
        total = np.sum(counts)
        shares = pd.DataFrame(counts/np.sum(counts), index = pct_cols).transpose()
        entropy = shares.apply(lambda row: gd_.get_entropy(row, IS), axis = 1)[0]
        return [total, entropy]

    def _rho_label(self, name, x, y):
        """Return ``name`` followed by the LaTeX-formatted Pearson correlation of x and y.

        The correlation is rounded to 2 decimals. Pairs where either value is
        NaN are ignored.
        """
        valid = x.notna() & y.notna()
        rho = pearsonr(x[valid], y[valid])[0]
        return name + r' ($\rho=$' + str(np.round(rho, 2)) + ')'

    def _plot_effort_panel(self, ax, frame, ylabel):
        """Regress ``frame.diversification`` on total effort in linear and log scale.

        The linear fit uses ``ax`` (bottom x-axis). The log fit uses a
        ``twiny`` axis (top x-axis). Each axis gets its own legend showing the
        correlation.
        """
        y = frame.diversification
        x_lin = frame.total_effort
        sns.regplot(x=x_lin, y=y, ax=ax, label=self._rho_label('Linear scale', x_lin, y), color='#ac1d1c')
        ax.set_xlabel('Total number of initiatives')
        ax.set_ylabel(ylabel)
        ax.legend(loc = 'center left', bbox_to_anchor = (.5, 0.055))
        ax_log = ax.twiny()
        x_log = frame.total_effort.apply(np.log)
        sns.regplot(x=x_log, y=y, color='#666999', ax=ax_log, label=self._rho_label('log-scale', x_log, y))
        ax_log.set_xlabel('Total number of initiatives (log)')
        ax_log.set_ylabel(ylabel)
        ax_log.legend(loc = 'center left', bbox_to_anchor = (0.05, 0.9))